package org.zjvis.datascience.spark.algorithm;

import org.alitouka.spark.dbscan.Dbscan;
import org.alitouka.spark.dbscan.DbscanModel;
import org.alitouka.spark.dbscan.DbscanSettings;
import org.alitouka.spark.dbscan.spatial.Point;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.mutable.WrappedArray.ofDouble;

import java.util.*;

/**
 * @description Spark DBSCAN operator [alitouka open-source implementation]
 * @date 2021-12-23
 */
public class DBscanAlgorithmV2 extends BaseAlgorithm {

    /** DBSCAN neighborhood radius (epsilon). */
    private double eps;

    /** Minimum number of neighboring points required to form a dense region. */
    private int minPoints;

    /**
     * NOTE(review): parsed from "--maxp" but never applied — the alitouka
     * {@code DbscanSettings} exposes no max-points knob. Kept so callers that
     * already pass the option do not break; remove once confirmed unused.
     */
    private int maxPoints;

    /** Columns assembled into the feature vector used for clustering. */
    private String[] featureCols;

    /** Name of the table the clustering result is written to. */
    private String targetTable;

    /** Name of the table the input rows are read from. */
    private String sourceTable;

    public DBscanAlgorithmV2(SparkSession sparkSession) {
        super(sparkSession);
    }

    /**
     * Declares the command-line options this operator supports and extracts
     * the parsed values into this instance's fields.
     *
     * @param args raw command-line arguments
     * @return true on success, false when the command line could not be parsed
     */
    public boolean parseParams(String[] args) {
        super.parseParams(args);
        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("idCol"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(OptionHelper.getSpecOption("eps", "eps", "eps", true));
        options.addOption(OptionHelper.getSpecOption("minp", "minPoints", "min points", true));
        options.addOption(OptionHelper.getSpecOption("maxp", "maxPoints", "max points", true));

        if (!this.initCommandLine(args)) {
            logger.error("initCommandLine fail!!!");
            return false;
        }

        // Commons CLI resolves short and long names to the same Option, so
        // getOptionValue(longName) works no matter which form was supplied.
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("idcol") || cmd.hasOption("idCol")) {
            idCol = cmd.getOptionValue("idCol", UtilTool.DEFAULT_ID_COL);
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("eps")) {
            eps = Double.parseDouble(cmd.getOptionValue("eps"));
        }
        if (cmd.hasOption("minp") || cmd.hasOption("minPoints")) {
            minPoints = Integer.parseInt(cmd.getOptionValue("minPoints"));
        }
        if (cmd.hasOption("maxp") || cmd.hasOption("maxPoints")) {
            maxPoints = Integer.parseInt(cmd.getOptionValue("maxPoints"));
        }
        return true;
    }

    /**
     * Wraps an assembled feature vector in an alitouka {@link Point}.
     * The precomputed distance value is the Euclidean (L2) norm of the
     * coordinates; cluster id -2 marks the point as not-yet-classified.
     *
     * @param id     row id carried through to the output
     * @param vector dense feature vector produced by the VectorAssembler
     * @return a Point ready to be fed to {@code Dbscan.train}
     */
    private Point createMultiDPoint(Long id, DenseVector vector) {
        double[] coordinates = vector.toArray();
        double squaredSum = 0.0;
        for (double coordinate : coordinates) {
            squaredSum += coordinate * coordinate;
        }
        return new Point(new ofDouble(coordinates), id, 1, Math.sqrt(squaredSum), 0, -2);
    }

    /**
     * Runs the full pipeline: load the source table (Hive or GreenPlum, with
     * an optional sampling cache), assemble the feature vector, cluster with
     * alitouka DBSCAN, and persist each row together with its cluster id.
     *
     * @return true on success, false when caching the sampled input failed
     */
    public boolean beginAlgorithm() {
        Dataset<Row> dataset;
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol);
            dataset = sparkSession.sql(sql);
        } else {
            if (isSample) {
                // Sampled reads are cached in a Spark table so repeated runs
                // do not hit GreenPlum again.
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable)
                            .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
                } else {
                    Dataset<Row> cacheRdd = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol, sampleNumber);
                    dataset = cacheRdd.select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol)
                        .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
            }
        }

        // Remember which feature columns were widened to double so the output
        // can be converted back to the source types after clustering.
        Map<String, DataType> schema = UtilTool.getDataTypeMaps(dataset.schema(), true);
        Map<String, Tuple2<DataType, DataType>> needChangeTypes = UtilTool.getNeedChangeDoubleTypeKeys(featureCols, schema);

        // "skip" drops rows with null/NaN feature values instead of failing.
        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol("features");
        Dataset<Row> transDF = assembler.transform(dataset).select(idCol, "features");

        // The RDD is consumed several times by Dbscan.train, so persist it.
        // assumes idCol is an integer column — TODO confirm against the source schema
        RDD<Point> pointRDD = transDF.toJavaRDD()
                .map(row -> createMultiDPoint((long) row.getInt(0), (DenseVector) row.get(1)))
                .rdd().persist(StorageLevel.MEMORY_AND_DISK());

        // NOTE: maxPoints is intentionally not applied — see the field comment.
        DbscanSettings dbscanSettings = new DbscanSettings()
                .withEpsilon(eps)
                .withNumberOfPoints(minPoints);

        DbscanModel model = Dbscan.train(pointRDD, dbscanSettings);

        // Rebuild output rows as: id, original features..., cluster id (string).
        JavaRDD<Row> rdd = model.allPoints().toJavaRDD().map(point -> {
            List<Object> fields = new ArrayList<>();
            // toIntExact fails loudly if an id does not fit in an int, unlike
            // the former Integer.valueOf(String.valueOf(...)) round-trip which
            // threw an opaque NumberFormatException.
            fields.add(Math.toIntExact(point.pointId()));
            for (double coordinate : point.coordinates().array()) {
                fields.add(coordinate);
            }
            fields.add(String.valueOf(point.clusterId()));
            return RowFactory.create(fields.toArray());
        });

        Map<String, DataType> extraTypes = new HashMap<>();
        extraTypes.put("cluster_id", DataTypes.StringType);
        List<String> keys = new ArrayList<>();
        keys.add(idCol);
        keys.addAll(Arrays.asList(featureCols));
        StructType newSchema = UtilTool.createSchema(schema, keys, 0, extraTypes);

        Dataset<Row> df = sparkSession.createDataFrame(rdd, newSchema);
        df = UtilTool.changeBackDataTypeForKey(df, needChangeTypes);

        if (isHive) {
            this.saveHiveTable(df, targetTable);
        } else {
            UtilTool.saveGreenplumTable(df, targetTable);
        }

        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable);

        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        pointRDD.unpersist(false);
        return true;
    }
}